import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import argparse
import seaborn as sns

def plot(**kw):
    """Plot mean cumulative regret with standard-error bars for several algorithms.

    For each algorithm in ``kw['algorithm']``, one ``.npy`` regret array per seed
    is loaded from ``./results-movielens/checkpoint/``, averaged across seeds and
    drawn as a curve, with error bars (standard error of the mean) every 500
    rounds up to ``T``. The figure is saved to
    ``./results-movielens/regret_images/Compare_Logistic_T=..._B=....pdf`` and
    then closed.

    Keyword Args:
        T (int): horizon; used in the file names and as the x-axis limit.
        N, K, d, lam, B: experiment parameters used in file names and the title.
        seed (iterable of int): seeds to average over; must be non-empty.
        algorithm (iterable of str): algorithms to plot. Only ``'C3UCB'``,
            ``'CLogUCB'``, ``'UCBCCA'`` and ``'UCBCLB'`` have a style; any other
            name is skipped without loading its files.
        oracle_type (str): oracle name used in the checkpoint file names.
        Any other keys (e.g. ``'C'``) are ignored.

    Raises:
        KeyError: if a required keyword is missing.
        ValueError: if ``seed`` is empty.
        OSError: if a checkpoint file cannot be opened.
    """
    T = kw['T']
    N = kw['N']
    K = kw['K']
    d = kw['d']
    lam = kw['lam']
    seeds = list(kw['seed'])
    algorithms = list(kw['algorithm'])
    oracle_type = kw['oracle_type']
    B = kw['B']

    if not seeds:
        raise ValueError("at least one seed is required to average regret")

    # Legend label and fixed palette index per algorithm, so each algorithm keeps
    # the same colour whichever subset of algorithms is plotted.
    styles = {
        'C3UCB': (r'C$^3$-UCB', 0),
        'CLogUCB': ('CLogUCB', 1),
        'UCBCCA': ('UCB-CCA', 2),
        'UCBCLB': ('UCB-CLB', 3),
    }
    error_step = 500  # rounds between error-bar markers

    plt.rcParams.update({'font.size': 14})
    fig = plt.figure(figsize=(8, 5.5))

    # The palette needs at least len(styles) entries because the indices above are
    # fixed; sizing it by len(algorithms) alone raised IndexError for short lists.
    colors = sns.color_palette("deep", n_colors=max(len(algorithms), len(styles)))

    for algorithm in algorithms:
        if algorithm not in styles:
            continue  # no style defined: nothing would be drawn, so skip loading
        label, color_idx = styles[algorithm]
        color = colors[color_idx]

        regret_results = []
        for seed in seeds:
            with open(f'./results-movielens/checkpoint/alg={algorithm}_T={T}_N={N}_K={K}_d={d}_lam={lam}_seed={seed}_oracle={oracle_type}_B={B}.npy', 'rb') as f:
                regret_results.append(np.load(f))

        regret_results = np.array(regret_results)
        # Prepend 0 for round 0. NOTE(review): assumes each saved array holds the
        # cumulative regret of rounds 1..T — confirm against the experiment script.
        mean_regret = np.concatenate((np.zeros(1), np.mean(regret_results, axis=0)))
        std_regret = np.concatenate((np.zeros(1), np.std(regret_results, axis=0)))
        se = std_regret / np.sqrt(len(seeds))

        # Error-bar positions are bounded by both the horizon and the data actually
        # loaded (previously hard-coded up to 5000, which broke for T < 5000).
        last = min(T, len(mean_regret) - 1)
        error_x = np.arange(error_step, last + 1, error_step)

        plt.plot(mean_regret[:T + 1], label=label, color=color)
        plt.errorbar(error_x, mean_regret[error_x], yerr=se[error_x],
                     fmt='o', color=color, alpha=0.7, capsize=5)

    plt.title(fr'$N$={N}, $K$={K}, $d$={d}', fontsize=20)
    plt.xlabel(r'Round ($t$)', fontsize=18)
    plt.ylabel('Cumulative Regret', fontsize=18)
    plt.xlim(0, T)
    plt.grid(True)
    plt.legend(loc='upper right')

    plt.savefig(f'./results-movielens/regret_images/Compare_Logistic_T={T}_N={N}_K={K}_d={d}_B={B}.pdf')
    # Release the figure so repeated calls in one process do not accumulate figures.
    plt.close(fig)

def plotK(**kw):
    """Plot one algorithm's mean cumulative regret for several cascade lengths K.

    For each ``K`` in ``kw['K']``, one ``.npy`` regret array per seed is loaded
    from ``./results-movielens/checkpoint/``, averaged across seeds and drawn as
    a curve. Error bars (standard error of the mean) are drawn every 500 rounds
    up to ``T``. The figure is saved to
    ``./results-movielens/regret_images/K_variants_alg=..._B=....pdf`` and then
    closed.

    Keyword Args:
        T (int): horizon; used in file names and as the x-axis limit.
        N, d, lam, B: experiment parameters used in file names and the title.
        K (iterable of int): cascade lengths to compare.
        seed (iterable of int): seeds to average over; must be non-empty.
        algorithm (str): a single algorithm name. Known names get a pretty title;
            any other name is shown verbatim.
        oracle_type (str): oracle name used in the checkpoint file names.
        Any other keys (e.g. ``'C'``) are ignored.

    Raises:
        KeyError: if a required keyword is missing.
        ValueError: if ``seed`` is empty.
        OSError: if a checkpoint file cannot be opened.
    """
    T = kw['T']
    N = kw['N']
    Ks = list(kw['K'])
    d = kw['d']
    lam = kw['lam']
    seeds = list(kw['seed'])
    algorithm = kw['algorithm']
    oracle_type = kw['oracle_type']
    B = kw['B']

    if not seeds:
        raise ValueError("at least one seed is required to average regret")

    # Display names for the title; unknown algorithms previously got no title at all.
    display_names = {
        'C3UCB': r'C$^3$-UCB',
        'UCBCCA': 'UCB-CCA',
        'CLogUCB': 'CLogUCB',
        'UCBCLB': 'UCB-CLB',
    }
    error_step = 500  # rounds between error-bar markers

    plt.rcParams.update({'font.size': 14})
    fig = plt.figure(figsize=(9, 6))

    colors = sns.color_palette("deep", n_colors=len(Ks))

    for idx, K in enumerate(Ks):
        regret_results = []
        for seed in seeds:
            with open(f'./results-movielens/checkpoint/alg={algorithm}_T={T}_N={N}_K={K}_d={d}_lam={lam}_seed={seed}_oracle={oracle_type}_B={B}.npy', 'rb') as f:
                regret_results.append(np.load(f))

        regret_results = np.array(regret_results)
        # Prepend 0 for round 0. NOTE(review): assumes each saved array holds the
        # cumulative regret of rounds 1..T — confirm against the experiment script.
        mean_regret = np.concatenate((np.zeros(1), np.mean(regret_results, axis=0)))
        std_regret = np.concatenate((np.zeros(1), np.std(regret_results, axis=0)))
        se = std_regret / np.sqrt(len(seeds))

        # Bound the error-bar positions by the data actually loaded as well as T.
        last = min(T, len(mean_regret) - 1)
        error_x = np.arange(error_step, last + 1, error_step)

        plt.plot(mean_regret[:T + 1], label=f'K={K}', color=colors[idx])
        plt.errorbar(error_x, mean_regret[error_x], yerr=se[error_x],
                     fmt='o', color=colors[idx], alpha=0.7, capsize=5)

    title_name = display_names.get(algorithm, algorithm)
    plt.title(fr'{title_name}, $N$={N}, $d$={d}', fontsize=20)
    plt.xlabel(r'Round ($t$)', fontsize=18)
    plt.ylabel('Cumulative Regret', fontsize=20)
    plt.xlim(0, T)
    plt.grid(True)
    plt.legend(loc='upper right')

    plt.savefig(f'./results-movielens/regret_images/K_variants_alg={algorithm}_T={T}_N={N}_d={d}_B={B}.pdf')
    # Release the figure so repeated calls in one process do not accumulate figures.
    plt.close(fig)

def plotK2(**kw):
    """Plot final cumulative regret against cascade length K, with a reference curve.

    For each ``K`` in ``kw['K']``, one ``.npy`` regret array per seed is loaded.
    Files are read from ``./results-movielens/checkpoint/``, and their names
    include ``B`` and ``C``. The last entry of each array is averaged across
    seeds; mean and standard error are drawn per K. A dashed 'Lower Bound'
    reference curve, ``K * (1/(1+e))**(K-1) * sqrt(T)``, is overlaid. The figure
    is saved to ``./results-movielens/regret_images/K_variants_T=..._d=....pdf``
    and then closed.

    Keyword Args:
        T (int): horizon; used in file names, the title, the y-label and the
            reference curve.
        N, d, lam, B, C: experiment parameters used in file names and the title.
        K (iterable of int): cascade lengths (x-axis values).
        seed (iterable of int): seeds to average over; must be non-empty.
        algorithm (str): algorithm name used in file names and the title.
        oracle_type (str): oracle name used in the checkpoint file names.

    Raises:
        KeyError: if a required keyword is missing.
        ValueError: if ``seed`` is empty.
        OSError: if a checkpoint file cannot be opened.
    """
    T = kw['T']
    N = kw['N']
    Ks = list(kw['K'])
    d = kw['d']
    lam = kw['lam']
    seeds = list(kw['seed'])
    algorithm = kw['algorithm']
    oracle_type = kw['oracle_type']
    B = kw['B']
    C = kw['C']

    if not seeds:
        raise ValueError("at least one seed is required to average regret")

    display_names = {
        'C3UCB': r'C$^3$-UCB',
        'UCBCCA': 'UCB-CCA',
        'CLogUCB': 'CLogUCB',
        'UCBCLB': 'UCB-CLB',
    }

    plt.rcParams.update({'font.size': 14})
    fig = plt.figure(figsize=(8, 5.5))

    x = np.array(Ks)
    means = np.zeros(len(Ks))
    ses = np.zeros(len(Ks))

    for idx, K in enumerate(Ks):
        regret_results = []
        for seed in seeds:
            with open(f'./results-movielens/checkpoint/alg={algorithm}_T={T}_N={N}_K={K}_d={d}_lam={lam}_seed={seed}_oracle={oracle_type}_B={B}_C={C}.npy', 'rb') as f:
                regret_results.append(np.load(f))

        regret_results = np.array(regret_results)
        # Only the last recorded round is used (the 0 prepended elsewhere for round 0
        # never affects index -1).
        means[idx] = np.mean(regret_results, axis=0)[-1]
        ses[idx] = np.std(regret_results, axis=0)[-1] / np.sqrt(len(seeds))

    marker_color = "#2ca02c"
    error_color = "#ff7f0e"
    # Shared by the curve and its legend handle so the two always match.
    bound_style = dict(linestyle='--', marker='o', color='red', markersize=4, linewidth=1)

    plt.plot(x, means, label='Avg. Regret', linewidth=2)
    plt.errorbar(
        x, means, yerr=ses,
        fmt='o', capsize=5, label='Std Error',
        color='orange', markersize=6, markerfacecolor=marker_color, markeredgecolor=marker_color,
        ecolor=error_color,
        linestyle='None')

    def y_func(k):
        # K * (1/(1+e))^(K-1) * 5 * sqrt(T). The horizon was previously hard-coded
        # as sqrt(5000) regardless of T.
        base = 1 / (1 + np.exp(1))
        return k * (base ** (k - 1)) * 5 * np.sqrt(T)

    # The 0.2 factor cancels the 5 above; kept so the plotted values are unchanged.
    lower_vals = 0.2 * y_func(x)
    plt.plot(x, lower_vals, label='Lower Bound', **bound_style)

    title_name = display_names.get(algorithm, algorithm)
    plt.title(fr'{title_name} with $T$={T}, $N$={N}, $d$={d}', fontsize=20)
    plt.xlabel(r'Cascade length ($K$)', fontsize=16)
    plt.ylabel(f'Cumulative Regret at round {T}', fontsize=16)
    plt.ylim(0, 160)
    # Custom legend: marker-only "Mean", error-bar colour "Std Error", and the
    # reference curve (previously drawn but missing from the legend).
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label='Mean', markerfacecolor=marker_color, markeredgecolor=marker_color, markersize=8),
        Line2D([0], [0], color=error_color, lw=2, label='Std Error'),
        Line2D([0], [0], label='Lower Bound', **bound_style),
    ]
    plt.legend(handles=legend_elements, fontsize=12)

    for spine in ['top', 'right']:
        plt.gca().spines[spine].set_visible(False)
    plt.tight_layout()

    plt.savefig(f'./results-movielens/regret_images/K_variants_T={T}_N={N}_d={d}.pdf')
    # Release the figure so repeated calls in one process do not accumulate figures.
    plt.close(fig)

def plot_time(**kw):
    """Plot mean running time with standard-error bars for several algorithms.

    For each algorithm in ``kw['algorithm']``, one ``.npy`` timing array per seed
    is loaded from ``./results-movielens/times/`` and divided by 1000. The arrays
    are averaged across seeds and drawn as a curve, with error bars (standard
    error of the mean) every 500 rounds up to ``T``. The figure is saved to
    ``./results-movielens/regret_images/Running_time_Logistic_T=..._B=....pdf``
    and then closed.

    Keyword Args:
        T (int): horizon; used in the file names and as the x-axis limit.
        N, K, d, lam, B: experiment parameters used in file names and the title.
        seed (iterable of int): seeds to average over; must be non-empty.
        algorithm (iterable of str): algorithms to plot. Only ``'C3UCB'``,
            ``'CLogUCB'``, ``'UCBCCA'`` and ``'UCBCLB'`` have a style; any other
            name is skipped without loading its files.
        oracle_type (str): oracle name used in the timing file names.
        Any other keys (e.g. ``'C'``) are ignored.

    Raises:
        KeyError: if a required keyword is missing.
        ValueError: if ``seed`` is empty.
        OSError: if a timing file cannot be opened.
    """
    T = kw['T']
    N = kw['N']
    K = kw['K']
    d = kw['d']
    lam = kw['lam']
    seeds = list(kw['seed'])
    algorithms = list(kw['algorithm'])
    oracle_type = kw['oracle_type']
    B = kw['B']

    if not seeds:
        raise ValueError("at least one seed is required to average running time")

    # Legend label and fixed palette index per algorithm, so each algorithm keeps
    # the same colour whichever subset of algorithms is plotted.
    styles = {
        'C3UCB': (r'C$^3$-UCB', 0),
        'CLogUCB': ('CLogUCB', 1),
        'UCBCCA': ('UCB-CCA', 2),
        'UCBCLB': ('UCB-CLB', 3),
    }
    error_step = 500  # rounds between error-bar markers

    plt.rcParams.update({'font.size': 14})
    fig = plt.figure(figsize=(8, 5.5))

    # The palette needs at least len(styles) entries because the indices above are
    # fixed; sizing it by len(algorithms) alone raised IndexError for short lists.
    colors = sns.color_palette("deep", n_colors=max(len(algorithms), len(styles)))

    for algorithm in algorithms:
        if algorithm not in styles:
            continue  # no style defined: nothing would be drawn, so skip loading
        label, color_idx = styles[algorithm]
        color = colors[color_idx]

        time_results = []
        for seed in seeds:
            with open(f'./results-movielens/times/alg={algorithm}_T={T}_N={N}_K={K}_d={d}_lam={lam}_seed={seed}_oracle={oracle_type}_B={B}.npy', 'rb') as f:
                times = np.load(f)
            # NOTE(review): presumably converts stored milliseconds to the seconds
            # shown on the y-axis — confirm the unit written by the experiment script.
            time_results.append(times / 1000)

        time_results = np.array(time_results)
        # Prepend 0 for round 0 so index t lines up with round t on the x-axis.
        mean_time = np.concatenate((np.zeros(1), np.mean(time_results, axis=0)))
        std_time = np.concatenate((np.zeros(1), np.std(time_results, axis=0)))
        se = std_time / np.sqrt(len(seeds))

        # Error-bar positions are bounded by both the horizon and the data actually
        # loaded (previously hard-coded up to 5000, which broke for T < 5000).
        last = min(T, len(mean_time) - 1)
        error_x = np.arange(error_step, last + 1, error_step)

        plt.plot(mean_time[:T + 1], label=label, color=color)
        plt.errorbar(error_x, mean_time[error_x], yerr=se[error_x],
                     fmt='o', color=color, alpha=0.7, capsize=5)

    plt.title(fr'$N$={N}, $K$={K}, $d$={d}', fontsize=20)
    plt.xlabel(r'Round ($t$)', fontsize=18)
    plt.ylabel('Time (seconds)', fontsize=18)
    plt.xlim(0, T)
    plt.grid(True)
    plt.legend(loc='upper right')

    plt.savefig(f'./results-movielens/regret_images/Running_time_Logistic_T={T}_N={N}_K={K}_d={d}_B={B}.pdf')
    # Release the figure so repeated calls in one process do not accumulate figures.
    plt.close(fig)

def main():
    """Parse command-line options and draw the running-time comparison figure.

    Every option has a default, so running with no arguments uses the default
    MovieLens setting. The parsed values are forwarded as keyword arguments to
    ``plot_time``. ``plot`` accepts the same keywords and draws the regret
    comparison instead.
    """
    parser = argparse.ArgumentParser()

    # Options for comparing algorithms at a single cascade length K.
    option_table = (
        ('--T', {'type': int, 'default': 5000}),
        ('--N', {'type': int, 'default': 1642}),
        ('--K', {'type': int, 'default': 5}),
        ('--dim', {'type': int, 'default': 25}),
        ('--lam', {'type': float, 'default': 1.0}),
        ('--seed', {'type': int, 'nargs': '*', 'default': range(10)}),
        ('--alg', {'nargs': '*', 'default': ['C3UCB', 'CLogUCB', 'UCBCCA', 'UCBCLB']}),
        ('--oracle', {'default': 'optimal'}),
        ('--B', {'type': float, 'default': 20.0}),
        ('--C', {'type': float, 'default': 1.0}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    plot_time(
        T=args.T,
        N=args.N,
        K=args.K,
        d=args.dim,
        lam=args.lam,
        seed=args.seed,
        algorithm=args.alg,
        oracle_type=args.oracle,
        B=args.B,
        C=args.C,
    )

def main_compare_K():
    """Parse command-line options and draw the regret-versus-K figure.

    ``--K`` accepts several cascade lengths and ``--alg`` a single algorithm
    name. The parsed values are forwarded as keyword arguments to ``plotK``.
    """
    parser = argparse.ArgumentParser()

    # Options for sweeping the cascade length K for one algorithm.
    option_table = (
        ('--T', {'type': int, 'default': 5000}),
        ('--N', {'type': int, 'default': 1642}),
        ('--K', {'type': int, 'nargs': '*', 'default': [1, 2, 3, 4, 5, 10, 15, 20]}),
        ('--dim', {'type': int, 'default': 25}),
        ('--lam', {'type': float, 'default': 1.0}),
        ('--seed', {'type': int, 'nargs': '*', 'default': range(10)}),
        ('--alg', {'default': 'UCBCLB'}),
        ('--oracle', {'default': 'optimal'}),
        ('--B', {'type': float, 'default': 20.0}),
        ('--C', {'type': float, 'default': 1.0}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    plotK(
        T=args.T,
        N=args.N,
        K=args.K,
        d=args.dim,
        lam=args.lam,
        seed=args.seed,
        algorithm=args.alg,
        oracle_type=args.oracle,
        B=args.B,
        C=args.C,
    )

if __name__ == "__main__":
    # Only the K-sweep figure is produced when run as a script; call main()
    # instead to draw the algorithm running-time comparison.
    main_compare_K()